from Libraries import *


def str_atk_tg_for_syn(net,x,target,criterion,eps1,eps2,rho_0,eta,lamda,sigma = 1e-2,iterations = 100):
    """Run an ADMM-style targeted attack that perturbs both the input and ``fc1.weight``.

    The input perturbation is carried in three coupled copies (``delta``,
    ``z``, ``w``) linked by the dual variables ``u`` (delta vs. z) and ``v``
    (w vs. z). A perturbation ``delta_a`` of the first layer's weight matrix
    is updated by a projected-gradient step (step size 0.1) followed by
    ``prox_infty``.

    The attack stops early when all of these hold:
    - The unperturbed input under the perturbed weights ``A + delta_a`` is
      classified as ``target``.
    - ``x + delta``, ``x + z`` and ``x + w`` are all classified as ``target``.
    - The perturbations are all within ``eps1`` in max-abs value.

    Args:
        net: model exposing ``fc1``/``fc2`` layers (their weights are read
            from the state dict). It is deep-copied; ``net`` itself is not
            modified.
        x: input tensor. NOTE(review): the code squeezes/reshapes ``x`` to a
            column vector and adds a batch dim to the model output, so a
            single un-batched sample is presumably expected — confirm.
        target: target label tensor; success means the argmax equals it.
        criterion: loss, called as ``criterion(output, target)``.
        eps1: max-abs bound on the input perturbation. It is enforced on
            ``w`` (together with keeping ``x + w`` in ``[-5, 5]``) and
            checked on delta/z/w for success.
        eps2: radius passed to ``prox_infty`` when projecting ``delta_a``.
        rho_0: initial ADMM penalty (``rho`` starts at ``1.01 * rho_0``).
        eta: proximal weight used in the ``z`` update.
        lamda: weight of the terms coupling ``delta`` and ``delta_a``.
        sigma: when the classifier is fooled, entries of delta/z/w with
            magnitude below ``sigma`` are zeroed (sparsification).
        iterations: maximum number of iterations.

    Returns:
        ``(delta, z, w, delta_a, eq_con)``. ``eq_con`` records
        ``||A w - delta_a x||`` for every iteration after the first.
    """
    model_v2= deepcopy(net)
    model   = deepcopy(net).to(device)
    A       = model.state_dict()['fc1.weight'].data
    B       = model.state_dict()['fc2.weight'].data
    x,target = x.to(device),target.to(device)
    x_col   = x.squeeze().reshape(-1,1)  # x as a column vector; x never changes

    # torch.solve is deprecated since PyTorch 1.9 and removed in later
    # releases: prefer torch.linalg.solve, fall back on older versions.
    use_linalg_solve = hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve')

    #Initializations:
    delta_a = torch.zeros_like(A).to(device) #Layer perturbation
    delta   = torch.zeros_like(x).to(device) #Input perturbation
    w       = torch.zeros_like(x).to(device) #Input perturbation (box/eps1-constrained copy)
    z       = torch.zeros_like(x).to(device) #Input perturbation (consensus copy)
    u       = torch.zeros_like(x).to(device) #First  dual variable (delta vs. z)
    v       = torch.zeros_like(x).to(device) #Second dual variable (w vs. z)
    sum_b   = torch.zeros_like(torch.eye(B.shape[1])).to(device) #Summation of diag(B[j,:])**2
    rho     = 1.01*rho_0
    lr      = 0.1  # step size for the delta_a gradient step
    eq_con  = []

    #sum of diag(B)**2 over the rows of B
    for j in range(B.shape[0]):
        sum_b += torch.diag(B[j,:])**2

    #The attack will start here
    for i in range(iterations):

        if (i+1)%50 == 0: #Increasing penalty after 50 iterations
            lamda = lamda * 1.01
            # NOTE(review): this *resets* rho to 1.01*rho_0, discarding the
            # 1.1x growth applied when the classifier is fooled, instead of
            # increasing it. Confirm whether `rho = 1.01*rho` was intended.
            rho   = 1.01*rho_0

        #Updating delta: solve (2*lamda*A^T A + (2+rho) I) delta = 2*lamda*A^T delta_a x + rho*z - u
        lhs = 2*lamda*torch.mm(A.t(),A) + (2+rho)*torch.eye(A.shape[1]).to(device)
        rhs = 2*lamda*torch.mm(torch.mm(A.t(),delta_a),x_col)
        rhs = rhs.reshape(z.shape) + rho*z.squeeze() - u.squeeze()
        rhs = rhs.reshape(-1,1)
        if use_linalg_solve:
            delta = torch.linalg.solve(lhs,rhs)
        else:
            delta,_ = torch.solve(rhs,lhs)
        delta = delta.reshape(z.shape)

        #Updating w: clip z - v/rho so that |w| <= eps1 and x + w stays in [-5, 5]
        w    = z - v/rho
        min1 = torch.min(5-x,torch.tensor(eps1).to(device))
        max1 = torch.max(-5-x,torch.tensor(-eps1).to(device))
        ind1 = w > min1
        ind2 = w < max1
        w[ind1] = min1[ind1]
        w[ind2] = max1[ind2]
        #The remaining entries keep their unclipped value

        #Updating z: needs grad of the loss w.r.t. the input (and w.r.t. fc1.weight for delta_a)
        perturbed_input = (x + z).detach().to(device).requires_grad_(True)
        output = model(perturbed_input).unsqueeze(dim=0)
        loss   = criterion(output,target)
        model.zero_grad()
        loss.backward()
        A_grad = model.fc1.weight.grad.data  # used below for the delta_a update
        grad_f = perturbed_input.grad.data

        b = delta + u/rho
        c = w     + v/rho
        z = (eta*z + rho*(b+c) - grad_f[0])/(eta + 2*rho)

        #Updating delta_a: gradient step followed by projection (prox_infty)
        temp    = torch.mm(delta_a, - 2*lamda*torch.mm(x_col,x_col.t()))
        temp    = temp - 2*lamda* torch.mm(torch.mm(A,delta.squeeze().reshape(-1,1)),x_col.t())
        temp    = temp - 2*torch.mm(sum_b,delta_a) - A_grad
        delta_a = prox_infty(delta_a + lr*temp,eps2)

        #Updating u,v (dual ascent)
        u += rho*(delta - z)
        v += rho*(w     - z)

        #Check whether the classifier is fooled
        out1 = model(x + delta).unsqueeze(dim=0).argmax(1)
        out2 = model(x + z    ).unsqueeze(dim=0).argmax(1)
        out3 = model(x + w    ).unsqueeze(dim=0).argmax(1)
        if out1 == target or out2 == target or out3 == target:
            rho   = rho*1.1
            eta   = eta*1.01
            #Sparsity: zero out small entries
            delta[torch.abs(delta) < sigma] = 0
            z[torch.abs(z) < sigma]         = 0
            w[torch.abs(w) < sigma]         = 0

        # Write the perturbed first-layer weights into model_v2 in place
        # (state_dict tensors share storage with the parameters).
        model_v2.state_dict()['fc1.weight'][:,:] = A+delta_a

        out0 = model_v2(x     ).unsqueeze(dim=0).argmax(1)
        out1 = model(x + delta).unsqueeze(dim=0).argmax(1)
        out2 = model(x + z    ).unsqueeze(dim=0).argmax(1)
        out3 = model(x + w    ).unsqueeze(dim=0).argmax(1)

        #Track the linear equality residual ||A w - delta_a x||
        if i > 0:
            err = torch.norm(torch.mm(A,w.squeeze().reshape(-1,1)) - torch.mm(delta_a,x_col))
            eq_con.append(err.item())

        if out1 == target and out2 == target and out3 == target and out0 == target:
            print('GG')
            cond1 = round(torch.max(torch.abs(delta)).item(),3) <= round(eps1,3)
            cond2 = round(torch.max(torch.abs(w)).item(),3) <= round(eps1,3)
            cond3 = round(torch.max(torch.abs(z)).item(),3) <= round(eps1,3)
            if cond1 and cond2 and cond3:
                print(f'All conditions are satisfied, and the attack is successful in {i+1} iterations')
                return delta,z,w,delta_a,eq_con
    print('The algorithm did not succeed in attacking the model')
    return delta,z,w,delta_a,eq_con




def prox_infty(A,eps = 1):
    """Project every row of ``A`` onto the L1 ball of radius ``eps``.

    Bounding each row's L1 norm by ``eps`` is the same as bounding the
    induced matrix infinity-norm (max absolute row sum) by ``eps``.

    Args:
        A: 2-D tensor to project. It is not modified.
        eps: radius of each row's L1 ball; must be > 0.

    Returns:
        A new tensor with the projected rows, moved to the module-level
        ``device``.
    """
    # Tensor.numpy() shares memory with a CPU tensor, so copy first.
    # Otherwise the row assignments below would silently overwrite the
    # caller's tensor. detach() also allows inputs that require grad.
    rows = A.detach().cpu().numpy().copy()
    for i in range(rows.shape[0]): # Projecting each row
        rows[i,:] = poject_on_l1_ball(rows[i,:],eps)
    return torch.from_numpy(rows).to(device)


def euclidean_proj_simplex(v, s=1):
    """Euclidean projection of a 1-D array onto the simplex of radius ``s``.

    Solves ``min_w ||w - v||_2  s.t.  sum(w) == s, w >= 0`` using the
    sort-based O(n log n) algorithm of Duchi et al. (2008).

    Args:
        v: 1-D numpy array.
        s: simplex radius; must be strictly positive.

    Returns:
        ``v`` itself (not a copy) if it already lies on the simplex.
        Otherwise a new array holding the projection.

    Raises:
        AssertionError: if ``s <= 0``.
        ValueError: if ``v`` is not 1-D.
    """
    assert s > 0, "Radius s must be strictly positive (%s <= 0)" % s
    if v.ndim != 1:
        raise ValueError(f"v must be 1-D, got shape {v.shape}")
    # Already on the simplex: nothing to do. (np.all replaces np.alltrue,
    # which was removed in NumPy 2.0.)
    if v.sum() == s and np.all(v >= 0):
        return v
    n = v.shape[0]
    u = np.sort(v)[::-1]  # entries in descending order
    cssv = np.cumsum(u)
    # Largest 0-based index rho with u[rho] * (rho + 1) > cssv[rho] - s.
    rho = np.nonzero(u * np.arange(1, n + 1) > (cssv - s))[0][-1]
    theta = (cssv[rho] - s) / (rho + 1.0)  # shift that makes the result sum to s
    return (v - theta).clip(min=0)


def poject_on_l1_ball(v, s=1):
    """Euclidean projection of a 1-D array onto the L1 ball of radius ``s``.

    ``|v|`` is projected onto the simplex of radius ``s`` with
    ``euclidean_proj_simplex`` and the signs of ``v`` are restored.
    The misspelled name is kept for backward compatibility with callers.

    Args:
        v: 1-D numpy array.
        s: ball radius; must be strictly positive.

    Returns:
        ``v`` itself (not a copy) if it already lies inside the ball.
        Otherwise a new array holding the projection.

    Raises:
        AssertionError: if ``s <= 0``.
        ValueError: if ``v`` is not 1-D.
    """
    # %s (not %d) so that fractional radii are reported accurately.
    assert s > 0, "Radius s must be strictly positive (%s <= 0)" % s
    if v.ndim != 1:
        raise ValueError(f"v must be 1-D, got shape {v.shape}")
    magnitudes = np.abs(v)
    if magnitudes.sum() <= s:
        return v  # already inside the ball
    projected = euclidean_proj_simplex(magnitudes, s=s)
    projected *= np.sign(v)
    return projected